{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58b27c92-44ac-479d-a737-7208be472aec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install llama-index modelscope"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "99b63427-5e6b-4559-9a97-7e6ae39c1d66",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-03-24T08:15:56.090361Z",
     "iopub.status.busy": "2024-03-24T08:15:56.090077Z",
     "iopub.status.idle": "2024-03-24T08:15:59.343513Z",
     "shell.execute_reply": "2024-03-24T08:15:59.343059Z",
     "shell.execute_reply.started": "2024-03-24T08:15:56.090344Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from llama_index.core import VectorStoreIndex, ServiceContext\n",
    "from llama_index.core.llms.callbacks import llm_completion_callback\n",
    "from llama_index.legacy.embeddings import HuggingFaceEmbedding\n",
    "from llama_index.legacy.llms import (CustomLLM, CompletionResponse, CompletionResponseGen, LLMMetadata)\n",
    "from llama_index.legacy.embeddings import HuggingFaceEmbedding\n",
    "\n",
    "from typing import Any\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel\n",
    "\n",
    "from modelscope import snapshot_download\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4fec5fd6-0e06-4824-95ed-2bb06f3f549f",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2024-03-24T08:16:05.517353Z",
     "iopub.status.busy": "2024-03-24T08:16:05.516808Z",
     "iopub.status.idle": "2024-03-24T08:18:26.912109Z",
     "shell.execute_reply": "2024-03-24T08:18:26.911633Z",
     "shell.execute_reply.started": "2024-03-24T08:16:05.517333Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warning: Pooling config file not found; pooling mode is defaulted to 'cls'.\n"
     ]
    }
   ],
   "source": [
    "# load embedding\n",
    "emb_path = snapshot_download('jieshenai/m3e-base')\n",
    "embedding_model = HuggingFaceEmbedding(emb_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f035abeb-5f46-4043-ae81-16e084667014",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-03-24T08:23:44.609749Z",
     "iopub.status.busy": "2024-03-24T08:23:44.609424Z",
     "iopub.status.idle": "2024-03-24T08:23:52.792689Z",
     "shell.execute_reply": "2024-03-24T08:23:52.792175Z",
     "shell.execute_reply.started": "2024-03-24T08:23:44.609730Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]/opt/conda/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
      "  return self.fget.__get__(instance, owner)()\n",
      "Loading checkpoint shards: 100%|██████████| 7/7 [00:06<00:00,  1.07it/s]\n"
     ]
    }
   ],
   "source": [
    "model_name = \"chatglm3-6b\"\n",
    "model_path = snapshot_download('ZhipuAI/chatglm3-6b')\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n",
    "model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()\n",
    "model = model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "495c253e-1018-49c4-a4e0-f4daf632a67b",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2024-03-24T08:30:58.678989Z",
     "iopub.status.busy": "2024-03-24T08:30:58.678653Z",
     "iopub.status.idle": "2024-03-24T08:30:58.684032Z",
     "shell.execute_reply": "2024-03-24T08:30:58.683541Z",
     "shell.execute_reply.started": "2024-03-24T08:30:58.678969Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# set context window size\n",
    "context_window = 2048\n",
    "# set number of output tokens\n",
    "num_output = 256\n",
    "\n",
    "\n",
    "class ChatGML(CustomLLM):\n",
    "    @property\n",
    "    def metadata(self) -> LLMMetadata:\n",
    "        \"\"\"Get LLM metadata.\"\"\"\n",
    "        return LLMMetadata(\n",
    "            context_window=context_window,\n",
    "            num_output=num_output,\n",
    "            model_name=model_name,\n",
    "        )\n",
    "\n",
    "    @llm_completion_callback()\n",
    "    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:\n",
    "        # prompt_length = len(prompt)\n",
    "\n",
    "        # only return newly generated tokens\n",
    "        text,_ = model.chat(tokenizer, prompt, history=[])\n",
    "        return CompletionResponse(text=text)\n",
    "\n",
    "    @llm_completion_callback()\n",
    "    def stream_complete(\n",
    "        self, prompt: str, **kwargs: Any\n",
    "    ) -> CompletionResponseGen:\n",
    "        raise NotImplementedError()\n",
    "\n",
    "llm_model = ChatGML()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e78a99bc-5ff1-41ca-aefd-380416eb3313",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2024-03-24T08:25:35.480885Z",
     "iopub.status.busy": "2024-03-24T08:25:35.480562Z",
     "iopub.status.idle": "2024-03-24T08:25:35.518308Z",
     "shell.execute_reply": "2024-03-24T08:25:35.517856Z",
     "shell.execute_reply.started": "2024-03-24T08:25:35.480865Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from llama_index.core import SimpleDirectoryReader\n",
    "\n",
    "documents = SimpleDirectoryReader('data').load_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4381ba29-b0b9-46fb-82d2-dfdf22404664",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import ServiceContext"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ffb36507-248b-44be-89f8-1e6d3e39f4b4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-03-24T08:31:02.252986Z",
     "iopub.status.busy": "2024-03-24T08:31:02.252674Z",
     "iopub.status.idle": "2024-03-24T08:31:02.421199Z",
     "shell.execute_reply": "2024-03-24T08:31:02.420722Z",
     "shell.execute_reply.started": "2024-03-24T08:31:02.252967Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_280/1247180598.py:1: DeprecationWarning: Call to deprecated class method from_defaults. (ServiceContext is deprecated, please use `llama_index.settings.Settings` instead.) -- Deprecated since version 0.10.0.\n",
      "  service_context = ServiceContext.from_defaults(llm=llm_model, embed_model=embedding_model)\n"
     ]
    }
   ],
   "source": [
    "service_context = ServiceContext.from_defaults(llm=llm_model, embed_model=embedding_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "f6f315ea-8324-4ff8-a635-7d8ac3430ac2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-03-24T08:31:07.334934Z",
     "iopub.status.busy": "2024-03-24T08:31:07.334608Z",
     "iopub.status.idle": "2024-03-24T08:31:07.338270Z",
     "shell.execute_reply": "2024-03-24T08:31:07.337820Z",
     "shell.execute_reply.started": "2024-03-24T08:31:07.334915Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ServiceContext(llm_predictor=LLMPredictor(system_prompt=None, query_wrapper_prompt=None, pydantic_program_mode=<PydanticProgramMode.DEFAULT: 'default'>), prompt_helper=PromptHelper(context_window=2048, num_output=256, chunk_overlap_ratio=0.1, chunk_size_limit=None, separator=' '), embed_model=HuggingFaceEmbedding(model_name='models/m3e-base', embed_batch_size=10, callback_manager=<llama_index.core.callbacks.base.CallbackManager object at 0x7efc9753d2d0>, tokenizer_name='models/m3e-base', max_length=512, pooling=<Pooling.CLS: 'cls'>, normalize=True, query_instruction=None, text_instruction=None, cache_folder=None), transformations=[SentenceSplitter(include_metadata=True, include_prev_next_rel=True, callback_manager=<llama_index.core.callbacks.base.CallbackManager object at 0x7efc9753d2d0>, id_func=<function default_id_func at 0x7efd8533f2e0>, chunk_size=1024, chunk_overlap=200, separator=' ', paragraph_separator='\\n\\n\\n', secondary_chunking_regex='[^,.;。？！]+[,.;。？！]?')], llama_logger=<llama_index.core.service_context_elements.llama_logger.LlamaLogger object at 0x7efea536ff40>, callback_manager=<llama_index.core.callbacks.base.CallbackManager object at 0x7efc9753d2d0>)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "service_context"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "4c79aaf2-609f-428a-b636-bca2cf86599c",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2024-03-24T08:32:20.570957Z",
     "iopub.status.busy": "2024-03-24T08:32:20.570630Z",
     "iopub.status.idle": "2024-03-24T08:32:21.325928Z",
     "shell.execute_reply": "2024-03-24T08:32:21.325435Z",
     "shell.execute_reply.started": "2024-03-24T08:32:20.570937Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "少女送给小风的神奇礼物是一把能够召唤风的力量的魔法扇。\n"
     ]
    }
   ],
   "source": [
    "# create index\n",
    "index = VectorStoreIndex.from_documents(documents, service_context=service_context)\n",
    "\n",
    "# query engine\n",
    "query_engine = index.as_query_engine()\n",
    "\n",
    "# query\n",
    "response = query_engine.query(\"少女感激不已，送给小风一件神奇的礼物是什么？\")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bb2f356-4aec-4fcd-9af8-a471cc08c3b4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
